{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# TVM Relay Pass\n",
    "\n",
    "References:\n",
    "\n",
    "1. {func}`tvm.transform.module_pass`\n",
    "1. {func}`tvm.relay.transform.function_pass`\n",
    "1. {func}`tvm.instrument.pass_instrument`\n",
    "1. [pass_infra](https://daobook.github.io/tvm/docs/arch/pass_infra.html)\n",
    "1. [How to Use TVM Pass Infra](https://daobook.github.io/tvm/docs/how_to/extend_tvm/use_pass_infra.html)\n",
    "\n",
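    "Optimizations of Relay/TIR programs can be applied at different granularities: at the function level, with {class}`tvm.relay.transform.FunctionPass` / {class}`tvm.tir.transform.PrimFuncPass`, and at the module level, with {class}`tvm.transform.ModulePass`. Alternatively, users can rely on {class}`tvm.transform.Sequential` to apply a sequence of passes to a Relay/TIR program, where the dependencies between passes are resolved by the pass infra.\n",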
    "\n",
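    "## Function-Level Pass\n",
    "\n",
    "{func}`~tvm.relay.transform.function_pass` is meant to be used as a decorator. When `pass_func` is omitted, it returns a decorator that wraps the transform function or class it is applied to; when `pass_func` is provided, it directly returns the function-level pass created from it.\n",
    "\n",
    "Parameters:\n",
    "\n",
    "- `pass_func`: the transform function or transform class.\n",
    "- `opt_level`: the optimization level of the pass.\n",
    "- `name`: the name of the pass. It may be omitted, in which case the name of the transform function or class is used.\n",
    "- `required`: the list of passes that this pass depends on.\n",
    "\n",
    "Let's go straight to an example.\n",
    "\n",
    "First, create a function `func`:"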
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "#[version = \"0.0.5\"]\n",
       "def @main(%a: Tensor[(10, 20), float32], %b: Tensor[(10, 20), float32]) {\n",
       "  %0 = add(%a, %b);\n",
       "  multiply(%0, %0)\n",
       "}"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.ir import IRModule\n",
    "\n",
    "\n",
    "a = relay.var(\"a\", shape=(10, 20))\n",
    "b = relay.var(\"b\", shape=(10, 20))\n",
    "c = a + b\n",
    "d = c * c\n",
    "func = relay.Function([a, b], d)\n",
    "input_mod = IRModule.from_expr(func)\n",
    "input_mod"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a function-level pass:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "@relay.transform.function_pass(opt_level=1)\n",
    "class TestReplaceFunc:\n",
    "    def __init__(self, new_func):\n",
    "        self.new_func = new_func\n",
    "\n",
    "    def transform_function(self, func, mod, ctx):\n",
    "        # For demonstration purposes, replace func with new_func\n",
    "        return self.new_func"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This pass simply replaces `func` with `new_func`.\n",
    "\n",
    "Create `new_func`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = relay.var(\"x\", shape=(10, 20))\n",
    "new_func = relay.Function([x], relay.log(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Instantiate the pass. `fpass` is a special pass that replaces every function with `new_func`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "fpass = TestReplaceFunc(new_func)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Applying `fpass` to `input_mod` returns a module in which every function has been replaced by `new_func`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "#[version = \"0.0.5\"]\n",
       "def @main(%x: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {\n",
       "  log(%x) /* ty=Tensor[(10, 20), float32] */\n",
       "}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "res_mod = fpass(input_mod)\n",
    "res_mod"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A function-level pass can also be created by decorating a user-defined `transform` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "@relay.transform.function_pass(opt_level=2)\n",
    "def transform(func, mod, ctx):\n",
    "    # Custom transformation\n",
    "    x = relay.var(\"x\", shape=(10, 20))\n",
    "    new_func = relay.Function([x], relay.log(x))\n",
    "    return new_func\n",
    "\n",
    "function_pass = transform\n",
    "assert isinstance(function_pass, relay.transform.FunctionPass)\n",
    "assert function_pass.info.opt_level == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given the module `input_mod`, the optimization can be invoked as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "#[version = \"0.0.5\"]\n",
       "def @main(%x: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {\n",
       "  log(%x) /* ty=Tensor[(10, 20), float32] */\n",
       "}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "updated_mod = function_pass(input_mod)\n",
    "updated_mod"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
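    "## Module-Level Pass\n",
    "\n",
    "The module-level pass decorator {func}`tvm.transform.module_pass` is defined and used much like {func}`~tvm.relay.transform.function_pass`. It likewise supports two modes: class mode and function mode.\n",
    "\n",
    "Class mode:"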
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.transform.module_pass(opt_level=2)\n",
    "class CustomPipeline:\n",
    "    def __init__(self, enable_fold):\n",
    "        self.enable_fold = enable_fold\n",
    "        self.cse = relay.transform.EliminateCommonSubexpr()\n",
    "        self.const_fold = relay.transform.FoldConstant()\n",
    "\n",
    "    def transform_module(self, mod, ctx):\n",
    "        mod = self.cse(mod)\n",
    "        if self.enable_fold:\n",
    "            mod = self.const_fold(mod)\n",
    "        return mod\n",
    "\n",
    "# Create an instance of the custom pipeline\n",
    "pipeline = CustomPipeline(enable_fold=False)\n",
    "assert isinstance(pipeline, tvm.transform.ModulePass)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def example():\n",
    "    shape = (1, 64, 54, 54)\n",
    "    c_data = np.empty(shape).astype(\"float32\")\n",
    "    c = relay.const(c_data)\n",
    "    weight = relay.var(\"weight\", shape=(64, 64, 3, 3))\n",
    "    x = relay.var(\"x\", relay.TensorType((1, 64, 56, 56), \"float32\"))\n",
    "    conv = relay.nn.conv2d(x, weight)\n",
    "    y = relay.add(c, c)\n",
    "    y = relay.multiply(y, relay.const(2, \"float32\"))\n",
    "    y = relay.add(conv, y)\n",
    "    z = relay.add(y, c)\n",
    "    z1 = relay.add(y, c)\n",
    "    z2 = relay.add(z, z1)\n",
    "    return relay.Function([x, weight], z2)\n",
    "\n",
    "m = IRModule.from_expr(example())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run `pipeline`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_module = pipeline(m)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Function mode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.transform.module_pass(opt_level=2)\n",
    "def transform(mod, ctx):\n",
    "    x = relay.var(\"x\", shape=(2,), dtype=\"float32\")\n",
    "    func = relay.Function([x], relay.abs(x))\n",
    "    new_mod = IRModule()\n",
    "    new_mod['var'] = func\n",
    "    new_mod.update(mod)\n",
    "    return new_mod\n",
    "\n",
    "module_pass = transform\n",
    "assert isinstance(module_pass, tvm.transform.ModulePass)\n",
    "assert module_pass.info.opt_level == 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "fn (%x: Tensor[(2), float32]) {\n",
       "  abs(%x)\n",
       "}"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
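    "# Given the module `m`, the optimization can be invoked as follows:\n",
    "updated_mod = module_pass(m)\n",
    "# The returned module now contains the new global function `var`, which computes `abs`.\n",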
    "\n",
    "updated_mod[\"var\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 ('torch': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "20e538bd0bbffa4ce75068aaf85df10d4944f3fdb705eeec6781a4702773116f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
